# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import itertools
import math
import unittest

import six
import six.moves as sm

from thrift import Thrift

import thrift.util.randomizer as randomizer

from fuzz import ttypes

class TestRandomizer(object):
    """Mixin shared by the randomizer test cases in this file.

    Provides the number of values each test draws and a factory that
    builds randomizers from a freshly created ``RandomizerState``.
    """

    # How many values each test pulls from the randomizer under test
    iterations = 1024

    def get_randomizer(self, ttypes, spec_args, constraints):
        """Return a randomizer for the given type, spec args and constraints.

        The randomizer comes from a new ``RandomizerState`` constructed
        with ``{"global_constraint": 100}``.
        """
        return randomizer.RandomizerState(
            {"global_constraint": 100}
        ).get_randomizer(ttypes, spec_args, constraints)

class TestBoolRandomizer(unittest.TestCase, TestRandomizer):
    """Checks for the randomizer returned for ``Thrift.TType.BOOL``."""

    def test_always_true(self):
        # With p_true set to 1.0 every draw is asserted truthy
        rand = self.get_randomizer(Thrift.TType.BOOL, None, {'p_true': 1.0})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertTrue(rand.generate())

    def test_always_false(self):
        # With p_true set to 0.0 every draw is asserted falsy
        rand = self.get_randomizer(Thrift.TType.BOOL, None, {'p_true': 0.0})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertFalse(rand.generate())

    def test_seeded(self):
        # p_random and p_fuzz are both 0 and the only seed is True,
        # so every draw is asserted truthy
        rand = self.get_randomizer(
            Thrift.TType.BOOL,
            None,
            {'seeds': [True], 'p_random': 0, 'p_fuzz': 0},
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertTrue(rand.generate())

    def test_int_seeded(self):
        # Same as test_seeded, but the seed is the integer 1
        rand = self.get_randomizer(
            Thrift.TType.BOOL,
            None,
            {'seeds': [1], 'p_random': 0, 'p_fuzz': 0},
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertTrue(rand.generate())

class TestEnumRandomizer(unittest.TestCase, TestRandomizer):
    """Checks for the I32 randomizer when given the ``ttypes.Color`` enum."""

    def test_always_valid(self):
        # p_invalid of 0: each value must be a key of Color._VALUES_TO_NAMES
        rand = self.get_randomizer(
            Thrift.TType.I32, ttypes.Color, {'p_invalid': 0})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertIn(rand.generate(), ttypes.Color._VALUES_TO_NAMES)

    def test_never_valid(self):
        # p_invalid of 1: no value may be a key of Color._VALUES_TO_NAMES
        rand = self.get_randomizer(
            Thrift.TType.I32, ttypes.Color, {'p_invalid': 1})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertNotIn(rand.generate(), ttypes.Color._VALUES_TO_NAMES)

    def test_choices(self):
        # Choices are given by enum name; the generated value is mapped
        # back to its name before checking membership
        allowed_names = ["RED", "BLUE", "BLACK"]
        rand = self.get_randomizer(
            Thrift.TType.I32, ttypes.Color, {'choices': allowed_names})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            self.assertIn(
                ttypes.Color._VALUES_TO_NAMES[generated], allowed_names)

    def test_seeded(self):
        # Seeds are given by enum name, with p_random and p_fuzz both 0
        seed_names = ["RED", "BLUE", "BLACK"]
        rand = self.get_randomizer(
            Thrift.TType.I32,
            ttypes.Color,
            {'seeds': seed_names, 'p_random': 0, 'p_fuzz': 0},
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            self.assertIn(
                ttypes.Color._VALUES_TO_NAMES[generated], seed_names)


class TestIntRandomizer(TestRandomizer):
    """Tests shared by every fixed-width integer randomizer.

    This class is a mixin: the concrete subclasses in this file also derive
    from ``unittest.TestCase`` and define two class attributes read here:
    ``ttype`` (passed to ``get_randomizer``) and ``n_bits`` (used to compute
    the signed range exposed by ``min`` and ``max``).
    """

    @classmethod
    def _one_bit_flipped(cls, a, b):
        """Return true if a and b differ in at most one bit position"""
        diff = a ^ b  # Bits set to 1 where a and b differ
        # If diff has only one `1` bit, subtracting one will clear that bit
        # Otherwise, the most significant 1 will not be cleared
        return 0 == (diff & (diff - 1))

    @classmethod
    def _within_delta(cls, a, b, delta):
        """Return true if a and b are at most `delta` apart"""
        return delta >= abs(a - b)

    @classmethod
    def _is_fuzzed_single_seed(cls, seed, fuzzed, delta):
        """Return true if `fuzzed` differs from `seed` in at most one bit
        or lies within `delta` of it"""
        return (cls._one_bit_flipped(seed, fuzzed) or
                cls._within_delta(seed, fuzzed, delta))

    @classmethod
    def is_fuzzed(cls, seeds, fuzzed, delta=None):
        """Check whether `fuzzed` could have been
        generated by fuzzing any element of `seeds`

        When `delta` is None, the `fuzz_max_delta` entry of the
        randomizer's `default_constraints` is used instead.
        """
        if delta is None:
            # Find the default `fuzz_max_delta` constraint
            default_randomizer = randomizer.RandomizerState().get_randomizer(
                cls.ttype, None, {})
            delta = default_randomizer.default_constraints['fuzz_max_delta']

        return any(cls._is_fuzzed_single_seed(seed, fuzzed, delta)
                   for seed in seeds)

    @property
    def min(self):
        """Smallest value of a signed integer that is `n_bits` wide"""
        cls = self.__class__
        n_bits = cls.n_bits
        return -(2 ** (n_bits - 1))

    @property
    def max(self):
        """Largest value of a signed integer that is `n_bits` wide"""
        cls = self.__class__
        n_bits = cls.n_bits
        return (2 ** (n_bits - 1)) - 1

    def testInRange(self):
        # With no constraints, every value must fit the signed range
        cls = self.__class__
        ttype = cls.ttype
        min_ = self.min
        max_ = self.max
        gen = self.get_randomizer(ttype, None, {})
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertGreaterEqual(val, min_)
            self.assertLessEqual(val, max_)

    def testConstant(self):
        # A single choice must be returned on every draw
        cls = self.__class__
        ttype = cls.ttype

        constant = 17

        constraints = {
            'choices': [constant]
        }

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # assertEqual: the `assertEquals` alias is deprecated and was
            # removed from unittest in Python 3.12
            self.assertEqual(val, constant)

    def testChoices(self):
        cls = self.__class__
        ttype = cls.ttype

        choices = [11, 17, 19]

        constraints = {'choices': choices}

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, choices)

    def testRange(self):
        # Both ends of the configured range are treated as allowed values
        cls = self.__class__
        ttype = cls.ttype

        range_ = [45, 55]

        constraints = {'range': range_}

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertGreaterEqual(val, range_[0])
            self.assertLessEqual(val, range_[1])

    def testRangeChoicePrecedence(self):
        # When both are given, values must come from `choices`, none of
        # which lie inside `range`
        cls = self.__class__
        ttype = cls.ttype

        range_ = [45, 55]
        choices = [11, 17, 19]

        constraints = {
            'range': range_,
            'choices': choices
        }

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, choices)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__
        ttype = cls.ttype

        seeds = [11, 17, 19]

        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)

    def testFuzzing(self):
        cls = self.__class__
        ttype = cls.ttype
        min_ = self.min
        max_ = self.max

        max_delta = 4
        # Two of the seeds sit half a delta away from the range limits, so
        # a full-delta fuzz of them would leave the range
        seeds = [
            0, self.max - int(max_delta / 2), self.min + int(max_delta / 2)
        ]

        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 1,
            'fuzz_max_delta': max_delta
        }

        gen = self.get_randomizer(ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertGreaterEqual(val, min_)
            self.assertLessEqual(val, max_)
            self.assertTrue(cls.is_fuzzed(seeds, val, max_delta))

class TestByteRandomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer tests for ``Thrift.TType.BYTE`` (8 bits)."""

    n_bits = 8
    ttype = Thrift.TType.BYTE

class TestI16Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer tests for ``Thrift.TType.I16`` (16 bits)."""

    n_bits = 16
    ttype = Thrift.TType.I16

class TestI32Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer tests for ``Thrift.TType.I32`` (32 bits)."""

    n_bits = 32
    ttype = Thrift.TType.I32

class TestI64Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer tests for ``Thrift.TType.I64`` (64 bits)."""

    n_bits = 64
    ttype = Thrift.TType.I64


class TestFloatRandomizer(TestRandomizer):
    """Tests shared by the floating point randomizers.

    This class is a mixin: the concrete subclasses in this file also derive
    from ``unittest.TestCase`` and define a ``randomizer_cls`` class
    attribute whose ``ttype`` is passed to ``get_randomizer``.
    """

    @property
    def randomizer_cls(self):
        """Return the `randomizer_cls` attribute of the test's own class"""
        return self.__class__.randomizer_cls

    def testZero(self):
        # p_zero of 1 with p_unreal of 0: every value must equal 0.0
        cls = self.__class__
        constraints = {
            'p_zero': 1,
            'p_unreal': 0
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertEqual(val, 0.0)

    def testNonZero(self):
        # p_zero of 0: no value may equal 0.0
        cls = self.__class__
        constraints = {
            'p_zero': 0,
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertNotEqual(val, 0.0)

    def testUnreal(self):
        # p_unreal of 1: every value must be NaN or infinite
        cls = self.__class__
        constraints = {
            'p_unreal': 1
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertTrue(
                math.isnan(val) or math.isinf(val)
            )

    def testReal(self):
        # p_unreal of 0: no value may be NaN or infinite
        cls = self.__class__
        constraints = {
            'p_unreal': 0
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertFalse(
                math.isnan(val) or math.isinf(val)
            )

    def testConstant(self):
        # A zero standard deviation must pin every value to the mean
        cls = self.__class__
        constant = 77.2
        constraints = {
            'mean': constant,
            'std_deviation': 0,
            'p_unreal': 0,
            'p_zero': 0
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # assertEqual: the `assertEquals` alias is deprecated and was
            # removed from unittest in Python 3.12
            self.assertEqual(val, constant)

    def testChoices(self):
        cls = self.__class__
        choices = [float('-inf'), 0.0, 13.37]
        constraints = {
            'choices': choices
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, choices)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__
        seeds = [float('-inf'), 0.0, 13.37]
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)

    def testIntSeeded(self):
        # Same as testSeeded, but the seeds are given as integers
        cls = self.__class__
        seeds = [1, 2, 3]
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)


class TestSinglePrecisionRandomizer(TestFloatRandomizer, unittest.TestCase):
    """Run the shared float tests for ``SinglePrecisionFloatRandomizer``."""

    ttype = Thrift.TType.FLOAT
    randomizer_cls = randomizer.SinglePrecisionFloatRandomizer

class TestDoublePrecisionRandomizer(TestFloatRandomizer, unittest.TestCase):
    """Run the shared float tests for ``DoublePrecisionFloatRandomizer``."""

    ttype = Thrift.TType.DOUBLE
    randomizer_cls = randomizer.DoublePrecisionFloatRandomizer

class TestStringRandomizer(TestRandomizer, unittest.TestCase):
    """Checks for the randomizer returned for ``Thrift.TType.STRING``."""

    def testInRange(self):
        # With no constraints, every character's code point must fall
        # inside StringRandomizer.ascii_range (both ends allowed)
        cls = self.__class__
        ascii_min, ascii_max = randomizer.StringRandomizer.ascii_range

        gen = self.get_randomizer(Thrift.TType.STRING, None, {})
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            for char in val:
                self.assertTrue(ascii_min <= ord(char) <= ascii_max)

    def testEmpty(self):
        # A mean length of 0 must yield only empty strings
        cls = self.__class__

        constraints = {'mean_length': 0}
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # Assert the length directly instead of from inside a loop over
            # the characters, which only ran when the string was non-empty
            self.assertEqual(0, len(val))

    def testChoices(self):
        cls = self.__class__

        choices = ['foo', 'bar']
        constraints = {'choices': choices}
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, choices)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__

        seeds = ['foo', 'bar']
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)

class TestListRandomizer(TestRandomizer, unittest.TestCase):
    """Checks for the randomizer returned for ``Thrift.TType.LIST``."""

    @classmethod
    def _has_extra_element(cls, shorter, longer):
        """Return True if the list `longer` can be created by inserting
        one element into `shorter`

        `longer` must be exactly one element longer than `shorter`.
        """
        assert(len(shorter) + 1 == len(longer))
        found_inserted = False
        for i, elem in enumerate(shorter):
            if not found_inserted and elem == longer[i]:
                # Still before the insertion point: indices line up
                continue
            # At or past the insertion point: every remaining element of
            # `shorter`, including the one at the first mismatch, must sit
            # one index later in `longer`
            found_inserted = True
            if elem != longer[i + 1]:
                return False
        return True

    @classmethod
    def _is_fuzzed_single_seed(cls, seed, fuzzed):
        """Check whether `fuzzed` could have been generated by fuzzing `seed`

        Requires that the elements of the list are of type i32
        """
        old_len = len(seed)
        new_len = len(fuzzed)
        len_delta = new_len - old_len
        if len_delta == -1:
            # An element was deleted. All other elements should be unaltered
            return cls._has_extra_element(fuzzed, seed)
        elif len_delta == 0:
            # An element was fuzzed
            different_elements = []
            for old, new in zip(seed, fuzzed):
                if old != new:
                    different_elements.append((old, new))
            if len(different_elements) == 0:
                # Fuzzed element was not changed
                return True
            elif len(different_elements) == 1:
                # Element was fuzzed
                seed_elem, fuzzed_elem = different_elements[0]
                return TestI32Randomizer.is_fuzzed([seed_elem], fuzzed_elem)
            else:
                return False
        elif len_delta == 1:
            # An element was inserted. All other elements should be unaltered
            return cls._has_extra_element(seed, fuzzed)
        else:
            return False

    @classmethod
    def is_fuzzed(cls, seeds, fuzzed):
        """Check whether `fuzzed` could have been generated by
        fuzzing any element of `seeds`. Requires that the elements
        of the list are of type i32"""
        return any(
            cls._is_fuzzed_single_seed(seed, fuzzed) for seed in seeds
        )

    def testEmpty(self):
        # A mean length of 0 must yield only empty lists
        cls = self.__class__

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)  # Elements are i32
        constraints = {'mean_length': 0}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # assertEqual: the `assertEquals` alias is deprecated and was
            # removed from unittest in Python 3.12
            self.assertEqual(len(val), 0)

    def testMaxLength(self):
        cls = self.__class__

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)  # Elements are i32
        constraints = {'mean_length': 100, 'max_length': 99}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        # Test to make sure that max length is enforced.
        #
        # Generate a lot of lists that should never be over 99 long
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertLessEqual(len(val), 99)

    def testElementConstraints(self):
        # The `element` constraints apply to the bool elements: with
        # p_true of 0 every element must be falsy
        cls = self.__class__

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.BOOL, None)
        constraints = {'element': {'p_true': 0}}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            for elem in val:
                self.assertFalse(elem)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__

        seeds = [
            [True, False, True],
            [False, False],
            []
        ]

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.BOOL, None)
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)

    def testFuzzed(self):
        # p_fuzz of 1 for the list and its elements: every value must be
        # explainable as a fuzz of one of the seeds
        cls = self.__class__

        seeds = [
            [],
            [1],
            [1, 2],
            [1, 2, 3],
            [1, 1, 1, 2, 2, 2, 3, 3, 3],
            [1, 1]
        ]

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 1,
            'element': {'p_random': 0, 'p_fuzz': 1}
        }

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertTrue(cls.is_fuzzed(seeds, val),
                            msg="val %s not generated by fuzzing" % val)

class TestSetRandomizer(TestRandomizer, unittest.TestCase):
    """Checks for the randomizer returned for ``Thrift.TType.SET``."""

    def testEmpty(self):
        # A mean length of 0 must yield only empty sets
        cls = self.__class__

        ttype = Thrift.TType.SET
        spec_args = (Thrift.TType.I32, None)  # Elements are i32
        constraints = {'mean_length': 0}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # assertEqual: the `assertEquals` alias is deprecated and was
            # removed from unittest in Python 3.12
            self.assertEqual(len(val), 0)

    def testElementConstraints(self):
        # The `element` constraints apply to the bool elements: with
        # p_true of 0 every element must be falsy
        cls = self.__class__

        ttype = Thrift.TType.SET
        spec_args = (Thrift.TType.BOOL, None)
        constraints = {'element': {'p_true': 0}}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            for elem in val:
                self.assertFalse(elem)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__

        seeds = [
            {1, 2, 3},
            set(),
            {-1, -2, -3}
        ]

        ttype = Thrift.TType.SET
        spec_args = (Thrift.TType.I32, None)
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)

class TestMapRandomizer(TestRandomizer, unittest.TestCase):
    """Checks for the randomizer returned for ``Thrift.TType.MAP``."""

    def testEmpty(self):
        # A mean length of 0 must yield only empty maps
        cls = self.__class__

        ttype = Thrift.TType.MAP
        spec_args = (Thrift.TType.I32, None, Thrift.TType.I16, None)
        constraints = {'mean_length': 0}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # assertEqual: the `assertEquals` alias is deprecated and was
            # removed from unittest in Python 3.12
            self.assertEqual(len(val), 0)

    def testKeyConstraints(self):
        # The `key` constraints apply to the bool keys: with p_true of 0
        # every key must be falsy
        cls = self.__class__

        ttype = Thrift.TType.MAP
        spec_args = (Thrift.TType.BOOL, None, Thrift.TType.I16, None)
        constraints = {'key': {'p_true': 0}}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            # Iterating the generated map yields its keys
            for key in val:
                self.assertFalse(key)

    def testValConstraints(self):
        # The `value` constraints apply to the Color values: with
        # p_invalid of 0 every value must be a declared enum value
        cls = self.__class__

        ttype = Thrift.TType.MAP
        spec_args = (Thrift.TType.I32, None, Thrift.TType.I32, ttypes.Color)
        constraints = {'value': {'p_invalid': 0}}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            for elem in six.itervalues(val):
                self.assertIn(elem, ttypes.Color._VALUES_TO_NAMES)

    def testSeeded(self):
        # p_random and p_fuzz are both 0: only the seeds may be returned
        cls = self.__class__

        seeds = [
            {1: "foo", 2: "fee", 3: "fwee"},
            {},
            {0: ""}
        ]

        ttype = Thrift.TType.MAP
        spec_args = (Thrift.TType.I32, None, Thrift.TType.STRING, None)
        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIn(val, seeds)


class TestStructRandomizer(TestRandomizer, unittest.TestCase):
    """Checks for the randomizer returned for ``Thrift.TType.STRUCT``.

    Subclasses may define a ``ttype`` class attribute; it is the struct
    type used by ``struct_randomizer`` when no type is passed in.
    """

    def get_spec_args(self, ttype):
        # Spec args are the triple (ttype, thrift_spec, is_union)
        thrift_spec = ttype.thrift_spec
        is_union = ttype.isUnion()
        return (ttype, thrift_spec, is_union)

    def struct_randomizer(self, ttype=None, constraints=None):
        # Fall back to the class-level `ttype` and to empty constraints
        struct_type = self.__class__.ttype if ttype is None else ttype
        spec_args = self.get_spec_args(struct_type)
        return self.get_randomizer(
            Thrift.TType.STRUCT, spec_args, constraints or {})

    def testNoFieldConstraints(self):
        # Struct-level p_include is 1.0, but fields `a` and `b` are each
        # given p_include of 0.0, so both must be left as None
        rand = self.struct_randomizer(
            ttypes.StructWithOptionals,
            {
                'per_field': {
                    'a': {'p_include': 0.0},
                    'b': {'p_include': 0.0}
                },
                'p_include': 1.0,
            },
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            self.assertIsNotNone(generated)
            self.assertIsNone(generated.a)
            self.assertIsNone(generated.b)

    def testOneTrueFieldConstraints(self):
        # Field `b` is excluded; field `a` is constrained with p_true of 1.0
        rand = self.struct_randomizer(
            ttypes.StructWithOptionals,
            {
                'per_field': {
                    'b': {'p_include': 0.0}
                },
                'p_include': 1.0,
                'a': {'p_true': 1.0},
            },
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            self.assertIsNotNone(generated)
            self.assertTrue(generated.a)
            self.assertIsNone(generated.b)

    def testStructSeed(self):
        # Seeds are supplied under the '|Rainbow' and '|NestedStructs'
        # keys; the test only requires that a value is always produced
        rainbow_seed = ttypes.Rainbow(colors=[ttypes.Color.RED])
        nested_seed = ttypes.NestedStructs(
            rainbow=ttypes.Rainbow(colors=[ttypes.Color.ORANGE]))
        rand = self.struct_randomizer(
            ttypes.NestedStructs,
            {
                '|Rainbow': {"seeds": [rainbow_seed]},
                '|NestedStructs': {'seeds': [nested_seed]},
            },
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertIsNotNone(rand.generate())

    def testNestedConstraints(self):
        # Same as testOneTrueFieldConstraints, plus constraints for the
        # nested field `c` that exclude its own field `b`
        nested_rules = {
            'p_include': 1.0,
            'per_field': {
                'b': {'p_include': 0.0}
            },
        }
        rand = self.struct_randomizer(
            ttypes.StructWithOptionals,
            {
                'per_field': {
                    'b': {'p_include': 0.0}
                },
                'p_include': 1.0,
                'a': {'p_true': 1.0},
                'c': nested_rules,
            },
        )
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            self.assertIsNotNone(generated)
            self.assertTrue(generated.a)
            self.assertIsNone(generated.b)
            self.assertIsNotNone(generated.c.a)
            self.assertIsNone(generated.c.b)

    def testEmptyUnion(self):
        rand = self.struct_randomizer(ttypes.EmptyUnion, {})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            # The union has no valid fields, so there is no reasonable
            # value to generate; None is the accepted result.
            self.assertIsNone(rand.generate())

    def testStructContainingDefaultUnion(self):
        # With no constraints a value must still always be produced
        rand = self.struct_randomizer(ttypes.NumberUnionStruct, {})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            self.assertIsNotNone(rand.generate())

    def testSubRanomizersHaveDefaults(self):
        # Pass a constraint key that nothing uses and that is not among the
        # existing defaults, then verify that this constraint given for
        # StructWithOptionals is not handed down to the field randomizers.
        unused_key = "targeted_constraint"
        rand = self.struct_randomizer(
            ttypes.StructWithOptionals, {unused_key: 100})

        # Reaching into the underscore-prefixed `_field_rules` is ugly, but
        # it shows which constraints are propagated and which are not.
        for _name, rule in rand._field_rules.items():
            field_randomizer = rule["randomizer"]

            # The targeted key applies to the top-level randomizer only
            self.assertNotIn(unused_key, field_randomizer.constraints.keys())

            # The global constraint reaches every randomizer
            self.assertIn("global_constraint", field_randomizer.constraints)


class TestUnionRandomizer(TestStructRandomizer, unittest.TestCase):
    """Union-specific checks, run with ``ttypes.IntUnion`` as the default
    struct type (the inherited struct tests run here as well)."""

    ttype = ttypes.IntUnion

    def testAlwaysInclude(self):
        rand = self.struct_randomizer(constraints={'p_include': 1})
        rounds = self.__class__.iterations

        for _ in sm.xrange(rounds):
            generated = rand.generate()
            # A nonzero `field` indicates that some field is set
            self.assertNotEqual(generated.field, 0)

    def testNeverInclude(self):
        rand = self.struct_randomizer(constraints={'p_include': 0})
        rounds = self.__class__.iterations
        reason = (
            "Because there's no way to add fields of a "
            "union there should be no way to create the union."
        )

        for _ in sm.xrange(rounds):
            # With p_include of 0 no union value is expected at all
            self.assertIsNone(rand.generate(), reason)

    def testSeededFuzz(self):
        # p_fuzz is not set here, only p_random of 0
        union_seeds = [ttypes.IntUnion(a=20), ttypes.IntUnion(b=40)]
        rand = self.struct_randomizer(
            constraints={"seeds": union_seeds, "p_random": 0})
        rounds = self.__class__.iterations
        reason = (
            "The union should always be created. "
            "We don't know the expected values, "
            "just that they exist"
        )

        for _ in sm.xrange(rounds):
            self.assertIsNotNone(rand.generate(), reason)

    def testSeeded(self):
        # Seeds are given both as plain dicts and as IntUnion instances,
        # with p_random and p_fuzz both 0
        union_seeds = [
            {'a': 2},
            {'b': 4},
            ttypes.IntUnion(a=2),
            ttypes.IntUnion(b=4),
        ]
        rand = self.struct_randomizer(
            constraints={
                'seeds': union_seeds,
                'p_random': 0,
                'p_fuzz': 0
            }
        )

        def is_seed(candidate):
            # Field id 1 must hold 2 and field id 2 must hold 4 (presumably
            # the ids of `a` and `b`); anything else is not a seed
            if candidate.field == 1:
                return candidate.value == 2
            if candidate.field == 2:
                return candidate.value == 4
            return False

        rounds = self.__class__.iterations
        for _ in sm.xrange(rounds):
            val = rand.generate()
            self.assertTrue(is_seed(val),
                            msg="Not a seed: %s (%s)" % (val, val.__dict__))

class TestListStructRandomizer(TestStructRandomizer, unittest.TestCase):
    """
    Verify that this struct type is generated correctly

    struct ListStruct {
        1: list<bool> a;
        2: list<i16> b;
        3: list<double> c;
        4: list<string> d;
        5: list<list<i32>> e;
        6: list<map<i32, i32>> f;
        7: list<set<string>> g;
    }
    """

    ttype = ttypes.ListStruct

    def testGeneration(self):
        """Each field that is set must be a list of the declared element type.

        A field left as None is skipped rather than treated as a failure.
        """
        gen = self.struct_randomizer()
        val = gen.generate()

        if val.a is not None:
            self.assertIsInstance(val.a, list)
            for elem in val.a:
                self.assertIsInstance(elem, bool)

        if val.b is not None:
            self.assertIsInstance(val.b, list)
            for elem in val.b:
                self.assertIsInstance(elem, int)

        if val.c is not None:
            self.assertIsInstance(val.c, list)
            for elem in val.c:
                self.assertIsInstance(elem, float)

        if val.d is not None:
            self.assertIsInstance(val.d, list)
            for elem in val.d:
                self.assertIsInstance(elem, six.string_types)

        if val.e is not None:
            self.assertIsInstance(val.e, list)
            for elem in val.e:
                self.assertIsInstance(elem, list)
                for sub_elem in elem:
                    self.assertIsInstance(sub_elem, int)

        if val.f is not None:
            self.assertIsInstance(val.f, list)
            for elem in val.f:
                self.assertIsInstance(elem, dict)
                for k, v in six.iteritems(elem):
                    self.assertIsInstance(k, int)
                    self.assertIsInstance(v, int)

        if val.g is not None:
            self.assertIsInstance(val.g, list)
            for elem in val.g:
                self.assertIsInstance(elem, set)
                for sub_elem in elem:
                    self.assertIsInstance(sub_elem, six.string_types)

    def testFieldConstraints(self):
        """Constraints keyed by field name must shape that field's values.

        With p_include of 1.0 every field is expected to be set. `a` may
        only hold true values, `b` must be empty, and every element of `d`
        and `e` must itself be empty.
        """
        cls = self.__class__

        constraints = {
            'p_include': 1.0,
            'a': {'element': {'p_true': 1.0}},
            'b': {'mean_length': 0.0},
            'd': {'element': {'mean_length': 0.0}},
            'e': {'element': {'mean_length': 0.0}},
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIsNotNone(val)
            self.assertIsNotNone(val.a)
            self.assertIsNotNone(val.b)
            self.assertIsNotNone(val.c)
            self.assertIsNotNone(val.d)
            self.assertIsNotNone(val.e)
            self.assertIsNotNone(val.f)
            self.assertIsNotNone(val.g)

            for elem in val.a:
                self.assertTrue(elem)

            # assertEqual, not the assertEquals alias: the alias was
            # deprecated and no longer exists on Python 3.12+.
            self.assertEqual(len(val.b), 0)

            for elem in val.d:
                self.assertEqual(len(elem), 0)

            for elem in val.e:
                self.assertEqual(len(elem), 0)

    def testSeeded(self):
        """With p_random and p_fuzz both 0, the seed is reproduced per field."""
        seeds = [{
            'a': [True, False, False],
            'b': [1, 2, 3],
            'c': [1.2, 2.3],
            'd': ["foo", "bar"],
            'e': [[]],
            'f': [{1: 2}, {3: 4}],
            'g': [{"foo", "bar"}]
        }]

        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 0
        }

        gen = self.struct_randomizer(constraints=constraints)
        val = gen.generate()
        for key, expected in six.iteritems(seeds[0]):
            self.assertEqual(expected, getattr(val, key, None),
                             msg="%s, %s" % (val, dir(val)))

    def testFuzz(self):
        """With p_fuzz of 1, at most one field may differ from the seed."""
        cls = self.__class__

        seeds = [{
            'a': [True, False, False],
            'b': [1, 2, 3],
            'c': [1.2, 2.3],
            'd': ["foo", "bar"],
            'e': [[]],
            'f': [{1: 2}, {3: 4}],
            'g': [{"foo", "bar"}]
        }]

        constraints = {
            'seeds': seeds,
            'p_random': 0,
            'p_fuzz': 1
        }

        gen = self.struct_randomizer(constraints=constraints)
        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            n_different = 0
            for key, expected in six.iteritems(seeds[0]):
                if expected != getattr(val, key, None):
                    n_different += 1
            self.assertLessEqual(n_different, 1)

    def testBoolTypeConstraints(self):
        """A '|bool' type rule with p_true of 1.0 makes every bool in `a` true."""
        cls = self.__class__

        constraints = {
            'p_include': 1.0,
            '|bool': {'p_true': 1.0}
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIsNotNone(val)
            self.assertIsNotNone(val.a)
            for elem in val.a:
                self.assertTrue(elem)

    def testDoubleTypeConstraints(self):
        """A '|double' type rule pinned to one value fixes every double in `c`.

        The rule sets the mean to a constant with a std_deviation of 0 and
        p_unreal / p_zero of 0; each element of `c` must equal the constant.
        """
        cls = self.__class__
        constant = 11.1

        constraints = {
            'p_include': 1.0,
            '|double': {
                'mean': constant,
                'std_deviation': 0,
                'p_unreal': 0,
                'p_zero': 0
            }
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIsNotNone(val)
            self.assertIsNotNone(val.c)
            for elem in val.c:
                self.assertEqual(elem, constant)

    def testListTypeConstraints(self):
        """
        We have an |i32 rule and a |list<i32> rule,
        The |list<i32> rule is more specific and overrides the |i32 rule
        """
        cls = self.__class__

        int_choices = [1, 2, 3]
        list_int_choices = [4, 5, 6]

        constraints = {
            'p_include': 1.0,
            '|list<i32>': {
                'element': {
                    'choices': list_int_choices
                }
            },
            '|i32': {
                'choices': int_choices
            }
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIsNotNone(val)
            self.assertIsNotNone(val.e)
            # Iterate through each i32 element in list<list<i32>>
            for elem in itertools.chain(*val.e):
                self.assertIn(elem, list_int_choices)

    def testMapTypeConstraints(self):
        """
        ListStruct.f is type list<map<i32, i32>>
        We use two rules:
        |i32
        |map<i32, i32>.key

        The |i32 rule will be used on map values, and the .key rule
        will be used on map keys since it is more specific.
        """
        cls = self.__class__

        int_choices = [1, 2, 3]
        key_int_choices = [4, 5, 6]

        constraints = {
            'p_include': 1.0,
            '|map<i32, i32>': {
                'key': {
                    'choices': key_int_choices
                }
            },
            '|i32': {
                'choices': int_choices
            }
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertIsNotNone(val)
            self.assertIsNotNone(val.f)
            # Chaining the dicts themselves iterates over their keys
            for key in itertools.chain(*val.f):
                self.assertIn(key, key_int_choices)
            # Use a distinct loop name so the generated struct bound to
            # `val` is not overwritten by individual map values.
            for value in itertools.chain(*map(dict.values, val.f)):
                self.assertIn(value, int_choices)


class TestNestedStruct(TestStructRandomizer, unittest.TestCase):
    """Randomizer tests for ttypes.NestedStructs, a struct of nested values."""

    ttype = ttypes.NestedStructs

    def testFuzz(self):
        """Check that only one subfield is randomized"""
        seed = {
            'ls': {
                'a': [True, True],
                'b': [1, 2, 3],
                'c': [1.2, 3.4],
                'd': ["foo"],
                'e': [[1], [2, 3]],
                'f': [{1: -1}, {-1: 1}],
                'g': [{"a", "b"}, set(), set()]
            },
            'rainbow': {
                'colors': ["YELLOW", "YELLOW", "GRAY"],
                'brightness': float('inf')
            },
            'ints': {
                'a': 1000
            }
        }

        # p_random is 0 at the top level and inside every nested field
        constraints = {
            'seeds': [seed],
            'p_random': 0,
            'p_fuzz': 1,
        }
        for nested_name in ('ls', 'rainbow', 'ints'):
            constraints[nested_name] = {'p_random': 0}

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(self.__class__.iterations):
            generated = gen.generate()

            # The seed names its colors while the generated value holds enum
            # numbers, so map numbers back to names before comparing.
            generated.rainbow.colors = [
                ttypes.Color._VALUES_TO_NAMES.get(color, color)
                for color in generated.rainbow.colors
            ]

            # One (path, seeded value, generated value) entry per mismatch
            mismatches = []

            # Walk the two nested structs field by field against the seed
            for struct_name in {'ls', 'rainbow'}:
                generated_struct = getattr(generated, struct_name)
                for key, seed_field in six.iteritems(seed[struct_name]):
                    generated_field = getattr(generated_struct, key, None)
                    if seed_field != generated_field:
                        mismatches.append((
                            '%s.%s' % (struct_name, key),
                            seed_field,
                            generated_field
                        ))

            # Fuzzed union should use the same field as the seed union (a=1)
            self.assertEqual(generated.ints.field, 1)

            seed_int = seed['ints']['a']
            generated_int = generated.ints.value
            if generated_int != seed_int:
                mismatches.append(('ints.a', seed_int, generated_int))

            message = ', '.join('%s: %s != %s' % diff for diff in mismatches)
            self.assertLessEqual(len(mismatches), 1, msg=message)

class TestStructRecursion(TestStructRandomizer, unittest.TestCase):
    """Check that max_recursion_depth bounds the depth of generated BTrees."""

    ttype = ttypes.BTree

    def max_depth(self, tree):
        """Find the maximum depth of a ttypes.BTree struct

        None has depth 0. A ttypes.BTreeBranch is replaced by its `child`
        before measuring, so the branch wrapper adds no level of its own.
        Any other node counts as 1 plus the deepest entry of `children`.
        """
        if tree is None:
            return 0

        if isinstance(tree, ttypes.BTreeBranch):
            tree = tree.child

        child_depths = [self.max_depth(child) for child in tree.children]
        # max() rejects an empty sequence, so a childless node contributes 0
        deepest_child = max(child_depths) if child_depths else 0

        return 1 + deepest_child

    def testDepthZero(self):
        """A max_recursion_depth of 0 must produce a value of depth 0."""
        constraints = {'max_recursion_depth': 0}

        gen = self.struct_randomizer(constraints=constraints)

        # Sanity check: the randomizer must have kept the requested depth
        if gen.constraints['max_recursion_depth'] != 0:
            raise ValueError('Invalid recursion depth %d' % (
                gen.constraints['max_recursion_depth']))

        val = gen.generate()

        # assertEqual, not the assertEquals alias: the alias was deprecated
        # and no longer exists on Python 3.12+.
        self.assertEqual(self.max_depth(val), 0)

    def testDepthOne(self):
        """A max_recursion_depth of 1 must never produce depth above 1."""
        cls = self.__class__
        constraints = {'max_recursion_depth': 1}

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertLessEqual(self.max_depth(val), 1)

    def testDepthTwo(self):
        """A max_recursion_depth of 2 must never produce depth above 2."""
        cls = self.__class__
        constraints = {'max_recursion_depth': 2}

        gen = self.struct_randomizer(constraints=constraints)

        for _ in sm.xrange(cls.iterations):
            val = gen.generate()
            self.assertLessEqual(self.max_depth(val), 2)

# Run the unittest command-line runner when this file is executed as a script
if __name__ == '__main__':
    unittest.main()
